{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e16c97bd-b8d5-42dc-8e24-ce5938fc588c",
   "metadata": {},
   "source": [
    "# Gene essentiality prediction with Bacformer tutorial\n",
    "\n",
    "This tutorial shows how to finetune the Bacformer model to predict gene essentiality.\n",
    "\n",
    "We use a dataset from the [Database of Essential Genes](http://origin.tubic.org/deg/public/index.php/browse/bacteria) for training and evaluation, and measure performance at the genome level.\n",
    "\n",
    "Before you start, make sure you have `bacformer` installed (see README.md for details) and run the notebook on a machine with a GPU.\n",
    "\n",
    "## Step 1: Import required dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "37a9a064-66af-4078-b830-30c0b3a48b6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from functools import partial\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from bacformer.modeling import (\n",
    "    SPECIAL_TOKENS_DICT,\n",
    "    BacformerTrainer,\n",
    "    collate_genome_samples,\n",
    "    compute_metrics_gene_essentiality_pred,\n",
    "    adjust_prot_labels,\n",
    ")\n",
    "from bacformer.pp import dataset_col_to_bacformer_inputs\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoConfig, AutoModelForTokenClassification, EarlyStoppingCallback, TrainingArguments"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d23e3c94-8e1c-4393-b8dc-4f9454ae40ed",
   "metadata": {},
   "source": [
    "## Step 2: Load the dataset\n",
    "\n",
    "We use a gene essentiality dataset preprocessed for this task. Each protein in a genome has a binary essentiality label, which we predict\n",
    "from 1) the protein sequence itself and 2) the whole-genome context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa3ac440-2df8-444a-92de-93e03d7cf399",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the dataset from HuggingFace\n",
    "dataset = load_dataset(\"macwiatrak/bacbench-essential-genes-protein-sequences\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "065e57c6-7cf7-4e48-918e-6ff71066ec34",
   "metadata": {},
   "source": [
    "## Step 3: Embed the dataset with the base protein language model (pLM)\n",
    "\n",
    "The first step in using Bacformer is to embed the protein sequences with the base pLM, [ESM-2 t12 35M](https://huggingface.co/facebook/esm2_t12_35M_UR50D).\n",
    "\n",
    "This step takes ~5 min on a single NVIDIA A100 GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3ced5e8-f20b-4b55-8f47-a830f5f82df0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# embed the protein sequences with the ESM-2 base model and map the labels\n",
    "for split_name in dataset.keys():\n",
    "    dataset[split_name] = dataset_col_to_bacformer_inputs(\n",
    "        dataset=dataset[split_name],\n",
    "        max_n_proteins=7000,\n",
    "    )\n",
    "    # map the essentiality labels to a binary format\n",
    "    dataset[split_name] = dataset[split_name].map(\n",
    "        lambda row: adjust_prot_labels(\n",
    "            labels=row[\"essential\"],\n",
    "            special_tokens=row[\"special_tokens_mask\"],\n",
    "        ),\n",
    "        batched=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2932ee81-6dae-425f-bcf3-43ebdcf37a52",
   "metadata": {},
   "source": [
    "## Step 4: Load the Bacformer model\n",
    "\n",
    "Load the pretrained Bacformer model with a classification head on top, which we then finetune."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ea2b640-86ab-4a0d-aa0c-3efbfb851afe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the Bacformer model for protein classification\n",
    "# for this task we use the Bacformer model trained on masked complete genomes\n",
    "# with a token (here protein) classification head\n",
    "config = AutoConfig.from_pretrained(\"macwiatrak/bacformer-masked-complete-genomes\", trust_remote_code=True)\n",
    "config.num_labels = 1\n",
    "config.problem_type = \"binary_classification\"\n",
    "\n",
    "bacformer_model = AutoModelForTokenClassification.from_pretrained(\n",
    "    \"macwiatrak/bacformer-masked-complete-genomes\", config=config, trust_remote_code=True\n",
    ").to(torch.bfloat16)\n",
    "\n",
    "print(\"Nr of parameters:\", sum(p.numel() for p in bacformer_model.parameters()))\n",
    "print(\"Nr of trainable parameters:\", sum(p.numel() for p in bacformer_model.parameters() if p.requires_grad))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ff273ec-80e3-4b97-a8fc-6bde31369abe",
   "metadata": {},
   "source": [
    "## Step 5: Set up the trainer for finetuning\n",
    "\n",
    "Set up a trainer object for finetuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ae1dfc7-fef3-49e2-8f57-40aa6a3a195d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the output directory for the model and metrics\n",
    "output_dir = \"output/gene_essentiality_pred\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# define the training arguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    save_total_limit=1,\n",
    "    learning_rate=0.00015,\n",
    "    num_train_epochs=100,\n",
    "    per_device_train_batch_size=1,\n",
    "    per_device_eval_batch_size=1,\n",
    "    gradient_accumulation_steps=8,\n",
    "    seed=1,\n",
    "    dataloader_num_workers=4,\n",
    "    bf16=True,\n",
    "    metric_for_best_model=\"eval_macro_auroc\",\n",
    "    load_best_model_at_end=True,\n",
    "    greater_is_better=True,\n",
    ")\n",
    "\n",
    "# define a collate function for the dataset\n",
    "collate_genome_samples_fn = partial(collate_genome_samples, SPECIAL_TOKENS_DICT[\"PAD\"], 7000, 1000)\n",
    "trainer = BacformerTrainer(\n",
    "    model=bacformer_model,\n",
    "    data_collator=collate_genome_samples_fn,\n",
    "    train_dataset=dataset[\"train\"],\n",
    "    eval_dataset=dataset[\"validation\"],\n",
    "    args=training_args,\n",
    "    compute_metrics=compute_metrics_gene_essentiality_pred,\n",
    "    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b0432eb-d575-46c1-ab45-d81a3052414a",
   "metadata": {},
   "source": [
    "## Step 6: Finetune the model 🎉🚂😎\n",
    "\n",
    "Finetune the model to predict gene essentiality. Training should take ~15 min on a single NVIDIA A100 GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cc64dc7-6726-434e-9556-de034d97a84b",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1e89559-1d30-4653-aef0-614a6281d891",
   "metadata": {},
   "source": [
    "## Step 7: Evaluate on the test set and run predictions\n",
    "\n",
    "With a trained model, you can now evaluate it on the test set and generate predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8787888-0df0-47ec-8b8e-db8349b260b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate the model on the test set\n",
    "test_output = trainer.predict(dataset[\"test\"])\n",
    "print(\"Test output:\", test_output.metrics)\n",
    "\n",
    "# get the predictions and labels for a single genome from the test set\n",
    "preds_strain = torch.sigmoid(torch.tensor(test_output.predictions.squeeze(-1)))[0]\n",
    "labels_strain = torch.tensor(test_output.label_ids)[0]\n",
    "genome_id = dataset[\"test\"][0][\"genome_id\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74d47526-6a58-4d51-9652-f5e561e466dc",
   "metadata": {},
   "source": [
    "## [Optional] Step 8: Plot gene essentiality probabilities\n",
    "\n",
    "Plot predicted gene essentiality probabilities for a single genome."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1376b8c5-b3fb-4e1d-9ce7-9b01cbd08cbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# build a DataFrame for plotting\n",
    "df = pd.DataFrame({'probability': preds_strain.tolist(), 'label': labels_strain.tolist()})\n",
    "# drop rows whose label is the ignore index (-100)\n",
    "df = df[df.label != -100]\n",
    "\n",
    "# Create the KDE plot\n",
    "plt.figure(figsize=(8, 6))\n",
    "sns.kdeplot(\n",
    "    data=df[df['label'] == 0],\n",
    "    x='probability',\n",
    "    fill=True,\n",
    "    color='blue',\n",
    "    alpha=0.6,\n",
    "    label='Non-essential genes'\n",
    ")\n",
    "sns.kdeplot(\n",
    "    data=df[df['label'] == 1],\n",
    "    x='probability',\n",
    "    fill=True,\n",
    "    color='goldenrod',\n",
    "    alpha=0.6,\n",
    "    label='Essential genes'\n",
    ")\n",
    "\n",
    "# Add title, legend and axis labels\n",
    "plt.title(f\"Gene Essentiality Prediction for {genome_id}\", fontsize=18)\n",
    "plt.legend(fontsize=20, title_fontsize=20, frameon=False, loc=\"upper left\")\n",
    "plt.xlabel(\"Predicted essentiality probability\", fontsize=16)\n",
    "plt.ylabel(\"Density\", fontsize=16)\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.xlim(-0.1, 1.1)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e16a45c-5928-4b20-9e94-1cc4dfca44bc",
   "metadata": {},
   "source": [
    "----------------------\n",
    "\n",
    "#### Voilà, you made it 👏!\n",
    "\n",
    "If you run into any problems or have questions, please open an issue on GitHub: https://github.com/macwiatrak/Bacformer/issues."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
